# Clear memory
# Removes every object from the current workspace, then triggers a garbage
# collection. Note that this does not detach packages or reset options.
rm(list = ls())
gc()

# Load packages
# require() returns FALSE (instead of erroring) when pacman is missing, in
# which case it is installed first. pacman::p_load then attaches each package
# listed below, installing any that are missing (pacman's default behaviour).
# NOTE(review): attach order determines which package's functions mask the
# others (e.g. tidylog is listed after tidyverse) -- keep the order as is
# unless masking has been checked.
if (!require("pacman")) install.packages("pacman")
pacman::p_load(
  data.table, janitor, magrittr, fixest, RSQLite,
  broom, DBI, readxl, lubridate, tidyverse, tidylog, 
  R.utils
)
# Sets the tidylog.display option to NULL, i.e. leaves it unset.
# NOTE(review): per the tidylog README this keeps its default message
# output switched on -- confirm that this is the intent.
options("tidylog.display" = NULL)

# Infix complement of %in%: TRUE for each left-hand element that does not
# appear in the right-hand set
`%notin%` = function(x, table) !(x %in% table)

######################################
# crosswalk fixing VA fips codes in bea
######################################

# Read the temporary crosswalk, coerce every column to character, left-pad
# fips with zeros to 5 characters, and keep only the unique fips codes
crosswalk = "data/simulation_fips_crosswalk_temp.csv" %>% 
  fread() %>% 
  as_tibble() %>% 
  mutate(across(everything(), as.character)) %>% 
  mutate(fips = str_pad(fips, width = 5, side = "left", pad = "0")) %>% 
  distinct(fips)

######################################
######################################
# for simulation
######################################
######################################

######################################
# migration
######################################

# https://www.baruch.cuny.edu/confluence/plugins/servlet/mobile#content/view/34572939
con = dbConnect(RSQLite::SQLite(), "raw/IRS/irs_migration_county.sqlite")

# read in origin-destination migration flows
# use returns to focus on workers instead of household members
# Result: one row per (origin, destination) pair with `migration` equal to the
# number of returns summed over the 1995-96 ... 1999-00 outflow tables.
migration = map_dfr(5:9, function(year){
  # Tables are named outflow_<4-digit start year>_<2-digit end year>:
  # outflow_1995_96, ..., outflow_1998_99, outflow_1999_00
  table_name = sprintf("outflow_%d_%02d", 1990L + year, (91L + year) %% 100L)
  dbReadTable(con, table_name) %>% 
      as_tibble() %>% 
      dplyr::select(origin, destination, returns) %>% 
      mutate(year = 1990 + year) %>% 
      # fold the old code 12025 into 12086 on both sides of the flow
      mutate(origin = ifelse(origin == "12025", "12086", origin),
             destination = ifelse(destination == "12025", "12086", destination)) %>% 
      arrange(origin, destination)
}) %>% 
    group_by(origin, destination) %>% 
    summarise(migration = sum(returns)) %>% 
    ungroup() 

# all tables have been read: release the SQLite connection
dbDisconnect(con)

# expand to full set of fips from our biggest crosswalk above
# stringsAsFactors = FALSE keeps origin/destination as character so they join
# against the character keys in `migration` without any factor coercion
migration = expand.grid(
    origin = crosswalk$fips,
    destination = crosswalk$fips,
    stringsAsFactors = FALSE
) %>% 
    left_join(migration, by = c("origin", "destination")) %>% 
    # Insert 0s if we don't observe migration
    # (only the flow column is filled; the key columns are left untouched)
    mutate(migration = replace_na(migration, 0))

# compute migration shares
# share = flow from origin to destination / total flow leaving that origin;
# origins with a total of 0 would give 0/0 = NaN, which is reset to 0
migration = migration %>% 
    group_by(origin) %>% 
    mutate(initial_pop = sum(migration, na.rm = TRUE)) %>% 
    ungroup() %>% 
    mutate(share = migration/initial_pop) %>% 
    mutate(share = ifelse(is.na(share), 0, share)) %>% 
    dplyr::select(origin, destination, share) 

# destinations where no one is going: we cannot identify the migration cost
# (a destination qualifies when its shares sum to exactly 0 across all origins)
no_shares = migration %>% 
  group_by(destination) %>% 
  filter(sum(share) == 0) %>% 
  ungroup() %>% 
  distinct(destination)

# remove those counties both as origins and as destinations, then sort
migration = migration %>% 
    filter(!(origin %in% no_shares$destination)) %>% 
    filter(!(destination %in% no_shares$destination)) %>% 
    arrange(origin, destination)

# put in matrix form: one row per origin, one share column per destination
migration_matrix = migration %>% 
    pivot_wider(
        names_from = "destination",
        values_from = "share"
    )

# Leading x's
names(migration_matrix)[-1] = paste0("x", names(migration_matrix)[-1])

# The diagonal fix and the identity matrix below both assume a square layout
# in which destination column j + 1 belongs to the county in origin row j
stopifnot(identical(
    names(migration_matrix)[-1],
    paste0("x", migration_matrix$origin)
))

# identify counties with no population flows
# (rows whose shares do not sum to 1 within a .01 tolerance)
no_pop = which(abs(rowSums(migration_matrix[, -1]) - 1) > .01 )

# assign 1: they just stay there, no migration
# Column 1 holds `origin`, so the own-county cell of row i is column i + 1.
# Indexing [i, i] would write into the previous county's column (and into
# `origin` itself for row 1).
for (county in no_pop) {
    migration_matrix[county, county + 1L] = 1
}

# Assuming migration_matrix is a square matrix
# Counterfactual with no migration at all: origin labels plus an identity matrix
migration_matrix_identity = cbind(migration_matrix[,1], diag(nrow(migration_matrix)))

######################################
# job switching
######################################

# crosswalk to convert census codes to naics codes for easier translation to manufacturing/non-manufacturing
# Output: distinct (naics, ind) pairs, both reduced to at most 3 leading digits
census_naics = read_xls("raw/CPS/industry-crosswalk-90-00-02-07-12.xls", skip = 23) %>%
    clean_names() %>%
    dplyr::select(naics = x1997_naics, ind = x1990_census) %>%
    drop_na() %>%
    # drop the first asterisk from each code
    mutate(across(c(naics, ind), ~ str_remove(.x, "[*]"))) %>%
    # one row per NAICS code when several are listed in a single cell
    separate_rows(naics, sep = ", ") %>%
    # keep the first run of digits, truncated to 3 characters
    mutate(across(c(naics, ind), ~ substr(str_extract(.x, "\\d+"), 1, 3))) %>%
    distinct() %>% 
    filter(!is.na(ind), !is.na(naics), str_length(ind) == 3)

# CPS extract: ages 25-65, years 1995-1999, with the census industry code
# translated to NAICS through census_naics; rows with fips 11 are dropped
job_switching = fread("raw/CPS/cps_00019.csv.gz") %>%
    clean_names() %>%
    filter(age >= 25, age <= 65) %>%
    filter(year >= 1995, year <= 1999) %>%
    dplyr::select(
        fips = statefip, serial, cpsid, cpsidp, year, month,
        hwtfinl, pernum, wtfinl, age, empstat, ind
    ) %>%
    arrange(cpsidp, year, month) %>%
    # zero-pad to 3 characters so the code matches the crosswalk's `ind`
    mutate(ind = str_pad(ind, width = 3, side = "left", pad = "0")) %>%
    inner_join(census_naics, by = "ind") %>% 
    dplyr::select(-ind) %>% 
    filter(fips != 11)

# Build a two-observation panel per person: sector at the first CPS
# observation ("initial_naics") and at the fifth row ("terminal_naics"),
# one row per person-period with columns fips, wtfinl, cpsidp, naics, time.
job_switching = job_switching %>% 
    # same columns, same order as produced by the read step above (re-stated for clarity)
    dplyr::select(fips, serial, cpsid, cpsidp, year, month, hwtfinl, pernum, wtfinl, age, empstat, naics) %>%
    # collapse naics into three sectors; conditions are evaluated in order,
    # so the employment-status rule takes precedence over the industry rule
    mutate(naics = case_when(
        empstat >= 13 ~ "Unemployed", # empstat >= 13 are all not employed
        naics == "31" | (as.numeric(naics) >= 310 & as.numeric(naics) <= 339) ~ "Manufacturing",
        TRUE ~ "Other"
    )
    ) %>% 
    arrange(cpsidp, year, month) %>%
    # combine year and month into "year-month" and parse it as a date
    # (truncated = 1 lets ymd() accept the missing day component)
    unite(ym, year, month, sep = "-") %>%
    mutate(ym = ymd(ym, truncated = 1)) %>%
    arrange(cpsidp, ym) %>%
    group_by(cpsidp) %>%
    # NOTE(review): slice() picks rows 1 and 5 by position. If the upstream
    # crosswalk join yields more than one row per person-month (one census code
    # mapping to several NAICS codes), row 5 is not necessarily the fifth
    # monthly interview -- verify.
    slice(c(1, 5)) %>% # take first observation in cps, then the observation 1 year later based on 4-8-4 scheme
    # groups with fewer than 5 rows contribute only row 1 and are removed here
    filter(n() > 1) %>% # keep only those who show up twice 
    mutate(time = c("initial_naics", "terminal_naics")) %>%
    # replace the CPS person id with the sequential index of its group
    mutate(cpsidp = cur_group_id()) %>% 
    ungroup() %>%
    dplyr::select(fips, wtfinl, cpsidp, ym, empstat, naics, time) %>%
    ungroup() %>% 
    # ym and empstat are dropped here
    distinct(fips, wtfinl, cpsidp, naics, time) %>%
    # fips is part of the grouping, so a person whose two rows carry different
    # fips values forms two single-row groups and is dropped by the filter
    group_by(fips, cpsidp) %>%
    filter(n() > 1) %>% # this keeps only people who were in both periods
    # use the average of the person's two weights for both rows
    mutate(wtfinl = mean(wtfinl)) %>% 
    # NOTE(review): looks like a no-op, since case_when above has a TRUE fallback
    drop_na(naics) %>% 
    ungroup() %>% 
    # drop fips 2 and 15 (presumably Alaska and Hawaii state codes -- confirm)
    filter(fips !=2 & fips != 15)

# Weighted sector-to-sector transition shares: for each initial sector, the
# fraction of total person weight that ends up in each terminal sector
job_switching_total = job_switching %>%
  group_by(cpsidp) %>%
  filter(n() > 1) %>%
  ungroup() %>% 
  dplyr::select(cpsidp, wtfinl, naics, time) %>%
  # one row per person with initial_naics and terminal_naics side by side
  pivot_wider(names_from = "time", values_from = "naics") %>%
  arrange(initial_naics, terminal_naics) %>%
  group_by(initial_naics, terminal_naics) %>%
  # drop_last leaves the result grouped by initial_naics for the normalisation
  summarise(n = sum(wtfinl, na.rm = TRUE), .groups = "drop_last") %>%
  mutate(n = n / sum(n, na.rm = TRUE)) %>%
  ungroup() 

# every (initial, terminal) sector pair, so that transitions absent from the
# data still get a row
full_set = expand.grid(
  initial_naics = c("Manufacturing", "Other", "Unemployed"),
  terminal_naics = c("Manufacturing", "Other", "Unemployed")
) %>% 
  distinct()

# Ensure positive labor in each county-sector if necessary
# Result: a 3-column table (Manufacturing, Other, Unemployed) with one row per
# initial sector, rows in the same alphabetical order as the columns
job_switching_total = job_switching_total %>% 
  full_join(full_set) %>%
  replace(is.na(.), 0) %>% 
  # smallest share observed before the missing pairs were added
  mutate(min_inflow = min(job_switching_total$n)) %>% # starts here
  group_by(terminal_naics) %>% 
  # a terminal sector that receives nothing gets min_inflow from every origin
  mutate(
    total_inflow = sum(n),
    n = ifelse(total_inflow == 0, min_inflow, n)
  ) %>% 
  # re-normalise so each initial sector's shares sum to 1 again
  group_by(initial_naics) %>%
  mutate(n = n/sum(n)) %>% 
  ungroup() %>% 
  dplyr::select(-c(total_inflow, min_inflow)) %>% # ends here
  arrange(initial_naics, terminal_naics) %>% 
  pivot_wider(names_from = "terminal_naics", values_from = "n") %>%
  replace(is.na(.), 0) %>%
  dplyr::select(-initial_naics) %>% 
  dplyr::select(sort(names(.)))

# Identity job-switching matrix (every worker stays in the initial sector),
# with the same column names as job_switching_total
job_switching_total_identity = diag(nrow(job_switching_total))
colnames(job_switching_total_identity) = names(job_switching_total)
job_switching_total_identity = as_tibble(job_switching_total_identity)

# Combine county-to-county shares and sector-to-sector shares into a single
# county-sector transition table via their Kronecker product.
#   county_shares: square numeric table; row i / column i belong to fips[i]
#   sector_shares: square numeric table whose column names are the sectors
#   fips:          county codes in the row order of county_shares
# Returns a tibble with an `origin` column followed by one column per
# destination county-sector, named "x<fips><first letter of sector>".
# Labels are generated in the product's own block order (sectors vary fastest
# within each county) instead of relying on sort() to reproduce that order.
build_county_sector_shares = function(county_shares, sector_shares, fips) {
  n_sectors = ncol(sector_shares)
  sector_suffix = tolower(substr(colnames(sector_shares), 1, 1))
  stopifnot(
    nrow(sector_shares) == n_sectors,
    nrow(county_shares) == length(fips),
    ncol(county_shares) == length(fips),
    !anyDuplicated(sector_suffix)
  )
  shares = kronecker(as.matrix(county_shares), as.matrix(sector_shares))
  # sector_suffix is recycled across the repeated county codes
  colnames(shares) = paste0("x", rep(fips, each = n_sectors), sector_suffix)
  bind_cols(tibble(origin = rep(fips, each = n_sectors)), as_tibble(shares))
}

# create full migration matrix as Kronecker product of migration and job switching
# rows should sum to 1
# columns should sum to > 0 to ensure the simulations run (otherwise migration costs aren't identified)
migration_out = build_county_sector_shares(
  migration_matrix[, -1],
  job_switching_total,
  migration_matrix$origin
)

write.csv(migration_out$origin, "data/migration_shares_counties_total.csv")
write.csv(migration_out[, -1], "data/migration_shares_matrix_total.csv")

# get final crosswalk: 3022 counties
# (keeps only the fips codes that survived into the migration matrix)
crosswalk = crosswalk %>% 
  inner_join(distinct(migration_out, fips = origin), by = "fips") %>% 
  arrange(fips)

fwrite(crosswalk, "data/simulation_fips_crosswalk.csv")

# counterfactual: observed migration, no job switching
migration_out_no_job = build_county_sector_shares(
  migration_matrix[, -1],
  job_switching_total_identity,
  migration_matrix$origin
)

write.csv(migration_out_no_job[, -1], "data/migration_shares_matrix_total_no_job.csv")

# counterfactual: no migration, observed job switching
migration_out_no_mig = build_county_sector_shares(
  migration_matrix_identity[, -1],
  job_switching_total,
  migration_matrix_identity$origin
)

write.csv(migration_out_no_mig[, -1], "data/migration_shares_matrix_total_no_mig.csv")
